# Scaling Diffusion Transformers Efficiently via μP

This is the implementation for our paper **"Scaling Diffusion Transformers Efficiently via μP"**. In this repository, we provide the code and instructions to reproduce our experiments on DiT and PixArt-α.

<img src="assets/vis.jpg" alt="vis" style="zoom:50%;" />



## Systematic investigation for DiT-μP on ImageNet

### Implementation of DiT-μP

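We use the coordinate check (coord check) method to verify that DiT-μP is implemented correctly. In the following figure, all curves stay roughly horizontal, i.e., activation scales neither grow nor shrink as the width increases, which indicates that μP is implemented correctly.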

```bash
python coordcheck.py --load_base_shapes width288_d28.bsh
```

![dit_coord](assets/dit_coord.jpg)

### HP transferability

As shown in the following figures (Figure 3 in our paper), DiT under μP enjoys robust hyperparameter (HP) transferability.

![HP_transfer](assets/HP_transfer.jpg)

We describe how to reproduce Figure 3(a) in our paper.

First, we train DiTs with different widths and learning rates under μP.

```bash
cd DiT/
# One node with 8x A100-80G GPUs.
# Sweep --num_heads over {2, 4, 8} and --loglr over {-9, -10, -11, -12, -13}.
torchrun --master_port=${MASTER_PORT} \
--master_addr=${MASTER_ADDR} \
--nproc_per_node=8 \
--nnodes=1 \
--node_rank=${NODE_RANK} \
train_mup.py \
--load_base_shapes width288_d28.bsh \
--mup \
--global_batch_size 256 \
--num_heads 4 \
--epochs 40 \
--loglr -10
```
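The sweep above amounts to 3 widths × 5 learning rates = 15 runs. As a minimal dry-run sketch, the script below prints one abbreviated launch command per configuration (the distributed-launch flags are omitted) and converts each `--loglr` value into a learning rate. It assumes that `--loglr` is the base-2 exponent of the base learning rate, which is consistent with `--loglr -10` corresponding to the $2^{-10}$ used for DiT-XL-2-μP. The helper names `loglr_to_lr` and `print_dit_sweep` are hypothetical.

```bash
#!/usr/bin/env bash
set -eu

# Assumption: --loglr is the base-2 exponent of the base learning rate (lr = 2^loglr).
loglr_to_lr() {
  awk -v e="$1" 'BEGIN { printf "%.10g\n", 2 ^ e }'
}

# Hypothetical dry-run helper: print one abbreviated launch command per
# (num_heads, loglr) pair of the Figure 3(a) sweep; nothing is executed.
print_dit_sweep() {
  for heads in 2 4 8; do
    for loglr in -9 -10 -11 -12 -13; do
      echo "torchrun --nproc_per_node=8 --nnodes=1 train_mup.py" \
        "--load_base_shapes width288_d28.bsh --mup --global_batch_size 256" \
        "--num_heads ${heads} --epochs 40 --loglr ${loglr}" \
        "# lr=$(loglr_to_lr "$loglr")"
    done
  done
}

print_dit_sweep
```

Printing instead of launching keeps the grid easy to inspect before committing GPUs; each printed line can be completed with the `--master_*` and `--node_rank` flags and submitted to your scheduler.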

Second, we sample 50K images from these trained DiTs.

```bash
cd DiT
# One node with 8x A100-80G or V100-32G GPUs.
# Set --num_heads to match the checkpoint: {2, 4, 8}.
torchrun --master_port=${MASTER_PORT} \
--master_addr=${MASTER_ADDR} \
--nproc_per_node=8 \
--nnodes=1 \
--node_rank=${NODE_RANK} \
sample_ddp.py \
--load_base_shapes width288_d28.bsh \
--mup \
--num_heads 4 \
--ckpt path_ckpt.pth \
--cfg_scale 1 \
--vae mse
```
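Because the sampling work is split evenly across GPUs, slightly more than 50K images may be generated. The sketch below reproduces that arithmetic under assumptions taken from upstream DiT rather than confirmed for this repository: a per-process batch size of 32 (the upstream default), a total that is rounded up to a multiple of the global sampling batch, and the surplus images being dropped when the `.npz` is packed. The function name `sample_plan` is hypothetical.

```bash
#!/usr/bin/env bash
set -eu

# Hypothetical helper: how many images each GPU generates when num_samples
# is rounded up to a multiple of the global sampling batch (upstream DiT logic).
sample_plan() {
  local num_samples=$1 per_proc_batch=$2 world_size=$3
  local global_batch total per_gpu
  global_batch=$(( per_proc_batch * world_size ))
  # Integer ceil(num_samples / global_batch), scaled back to a number of images.
  total=$(( (num_samples + global_batch - 1) / global_batch * global_batch ))
  per_gpu=$(( total / world_size ))
  echo "global_batch=${global_batch} total=${total} per_gpu=${per_gpu} iters=$(( per_gpu / per_proc_batch ))"
}

# 50K images on 8 GPUs with 32 images per GPU per step.
sample_plan 50000 32 8
```

Under these assumptions the run produces 50,176 images (196 iterations per GPU), and the 176 extra images are discarded.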

Third, we evaluate the performance (e.g., FID, IS, sFID) of these DiTs.

```bash
cd DiT
python create_npz.py # get sampled_50K_images.npz
python evaluator.py \
--ref_batch path/VIRTUAL_imagenet256_labeled.npz \
--sample_batch sampled_50K_images.npz
```

Finally, we can plot the figures with these data.

### Pretrain DiT-XL-2-μP

The best base learning rate found on the small models is $2^{-10}$ (`--loglr -10`); we transfer it directly to pretrain DiT-XL-2-μP.

```bash
cd DiT/
# 4 nodes with 8x A100-80G GPUs each.
torchrun --master_port=${MASTER_PORT} \
--master_addr=${MASTER_ADDR} \
--nproc_per_node=8 \
--nnodes=4 \
--node_rank=${NODE_RANK} \
train_mup.py \
--load_base_shapes width288_d28.bsh \
--mup \
--global_batch_size 256 \
--num_heads 16 \
--epochs 480 \
--loglr -10
```
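Compared with the proxy runs, only `--num_heads`, the number of epochs, and the number of nodes change. As a rough sketch, the helpers below compute the resulting hidden width and per-GPU batch size. Both rest on our own assumptions: the global batch is split evenly across all GPUs (as in upstream DiT's `train.py`), and the head dimension is fixed at 72, which matches DiT-XL/2 (hidden size 1152 with 16 heads) and the 288-wide base shape (4 heads). `per_gpu_batch` and `dit_width` are hypothetical names.

```bash
#!/usr/bin/env bash
set -eu

# Hypothetical helper: per-GPU batch = global batch / (GPUs per node * nodes).
per_gpu_batch() {
  local global_batch=$1 gpus_per_node=$2 nodes=$3
  local world_size=$(( gpus_per_node * nodes ))
  if [ $(( global_batch % world_size )) -ne 0 ]; then
    echo "global batch ${global_batch} is not divisible by ${world_size} GPUs" >&2
    return 1
  fi
  echo $(( global_batch / world_size ))
}

# Hypothetical helper, assuming a fixed head dimension of 72: width = heads * 72.
dit_width() {
  echo $(( $1 * 72 ))
}

echo "proxy:  width=$(dit_width 4) per_gpu_batch=$(per_gpu_batch 256 8 1)"
echo "target: width=$(dit_width 16) per_gpu_batch=$(per_gpu_batch 256 8 4)"
```

Under these assumptions the target model is 4× wider than the 288-wide base shape recorded in `width288_d28.bsh`, while each GPU processes 8 instead of 32 images per step.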

To reproduce the original DiT-XL-2 pretraining, we can run

```bash
cd DiT/
# 4 nodes with 8x A100-80G GPUs each.
torchrun --master_port=${MASTER_PORT} \
--master_addr=${MASTER_ADDR} \
--nproc_per_node=8 \
--nnodes=4 \
--node_rank=${NODE_RANK} \
train_mup.py \
--global_batch_size 256 \
--num_heads 16 \
--epochs 1400
```
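The two DiT-XL-2 runs differ mainly in training length (480 vs. 1400 epochs). The sketch below converts epochs into approximate optimizer steps, assuming the 1,281,167-image ImageNet-1K training split and that the last incomplete global batch of each epoch is dropped, so one epoch is ⌊1,281,167 / 256⌋ = 5,004 steps. The exact count in this repository may differ slightly; `total_steps` is a hypothetical helper.

```bash
#!/usr/bin/env bash
set -eu

IMAGENET_TRAIN_IMAGES=1281167  # size of the ImageNet-1K training split

# Hypothetical helper: approximate optimizer steps for a given number of epochs,
# assuming the last incomplete global batch of each epoch is dropped.
total_steps() {
  local epochs=$1 global_batch=$2
  echo $(( IMAGENET_TRAIN_IMAGES / global_batch * epochs ))
}

echo "DiT-XL-2-muP: $(total_steps 480 256) steps"
echo "DiT-XL-2:     $(total_steps 1400 256) steps"
# Ratio of training lengths in epochs: 1400 / 480
awk 'BEGIN { printf "ratio=%.2f\n", 1400 / 480 }'
```

Under these assumptions, the μP run uses about 2.4M steps versus about 7.0M, i.e., roughly 34% of the original training length.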

Sampling and evaluation follow the same steps as above (with `--num_heads 16`).

## Scaling PixArt-α-μP on SA-1B

### Dataset

We use the SA-1B/SAM dataset, following the instructions in the PixArt-α repository.

### Implementation of PixArt-α-μP

Its correctness can also be verified with the coord check method:

```bash
python scripts/coordcheck.py \
--load_base_shapes L28_width288.bsh \
--config configs/pixart_config/PixArt_mup_img256_SAM_coord.py \
--work_dir output/pixelart_coordcheck
```

![pixart_coord](assets/pixart_coord.jpg)

### HP search on proxy models

To reproduce the base learning rate search on the PixArt-α-μP proxy tasks (Table 2 in our paper), run:

```bash
cd PixArt-alpha-master/
# One node with 8x A100-80G GPUs.
# Sweep --loglr over {-9, -10, -11, -12, -13} and set --work-dir accordingly.
torchrun --master_port=${MASTER_PORT} \
--master_addr=${MASTER_ADDR} \
--nproc_per_node=8 \
--nnodes=1 \
--node_rank=${NODE_RANK} \
train_scripts/train.py \
--config configs/pixart_config/PixArt_mup_xl2_img256_SAM_proxy.py \
--work-dir output/search_SAM_256/loglr-10 \
--load_base_shapes L28_width288.bsh \
--loglr -10
```
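To run the full search, the command is repeated once per `--loglr` value, with each run writing to its own `--work-dir`. The dry-run sketch below prints one abbreviated command per value, following the `loglr-10` naming of the work directory in the command above; `print_pixart_sweep` is a hypothetical helper, and the distributed-launch flags are omitted.

```bash
#!/usr/bin/env bash
set -eu

# Hypothetical dry-run helper: one proxy-search command per base learning rate.
# The work directory embeds the exponent, e.g. loglr-10 for --loglr -10.
print_pixart_sweep() {
  for loglr in -9 -10 -11 -12 -13; do
    echo "torchrun --nproc_per_node=8 --nnodes=1 train_scripts/train.py" \
      "--config configs/pixart_config/PixArt_mup_xl2_img256_SAM_proxy.py" \
      "--work-dir output/search_SAM_256/loglr${loglr}" \
      "--load_base_shapes L28_width288.bsh --loglr ${loglr}"
  done
}

print_pixart_sweep
```

Keeping the exponent in the directory name makes it straightforward to match each run's logs to its learning rate when compiling the search results.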

### Pretrain PixArt-α-μP

We use the best base learning rate from the proxy search, $2^{-10}$, to train PixArt-α-μP.

```bash
cd PixArt-alpha-master/
# 4 nodes with 8x A100-80G GPUs each.
torchrun --master_port=${MASTER_PORT} \
--master_addr=${MASTER_ADDR} \
--nproc_per_node=8 \
--nnodes=4 \
--node_rank=${NODE_RANK} \
train_scripts/train.py \
--config configs/pixart_config/PixArt_mup_xl2_img256_SAM_target.py \
--work-dir output/pretrain_SAM_256_mup/loglr-10 \
--load_base_shapes L28_width288.bsh
```

To train the original PixArt-α, we can run

```bash
cd PixArt-alpha-master/
# 4 nodes with 8x A100-80G GPUs each.
torchrun --master_port=${MASTER_PORT} \
--master_addr=${MASTER_ADDR} \
--nproc_per_node=8 \
--nnodes=4 \
--node_rank=${NODE_RANK} \
train_scripts/train.py \
--config configs/pixart_config/PixArt_xl2_img256_SAM.py \
--work-dir output/train_SAM_256
```

### Evaluation

#### MS-COCO and MJHQ

First, to compute FID and CLIP score on the MS-COCO and MJHQ-30K datasets, generate images from their prompts:

```bash
python scripts/inference.py \
--config config_path.py \
--load_base_shapes L28_width288.bsh \
--model_path ckpt_path.pth \
--dataset mjhq # or mscoco
```

Second, compute the FID between the reference and generated image sets:

```bash
python tools/fid.py \
--ref_dir data/mjhq/imgs \
--fake_dir sampled_imgs
```

Third, to obtain the CLIP score of sampled images, we can run

```bash
python tools/clip_score.py \
--image_dir sampled_imgs \
--save_path result.csv
```
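The three MJHQ-30K steps above can be chained into one script. The sketch below is a hypothetical wrapper (`run` and `eval_mjhq` are our names, not part of the repository) that assumes `scripts/inference.py` writes its images to `sampled_imgs`, as the later commands expect; with `DRY_RUN=1` it only prints the commands. MS-COCO would additionally need `--dataset mscoco` and a matching `--ref_dir`, which is not given above.

```bash
#!/usr/bin/env bash
set -eu

# Run a command, or only print it when DRY_RUN=1.
run() {
  if [ "${DRY_RUN:-0}" = 1 ]; then
    echo "$*"
  else
    "$@"
  fi
}

# Hypothetical wrapper: sample MJHQ-30K prompts, then compute FID and CLIP score.
# Assumes scripts/inference.py writes its images to sampled_imgs/.
eval_mjhq() {
  local config=$1 ckpt=$2
  run python scripts/inference.py --config "$config" \
    --load_base_shapes L28_width288.bsh --model_path "$ckpt" --dataset mjhq
  run python tools/fid.py --ref_dir data/mjhq/imgs --fake_dir sampled_imgs
  run python tools/clip_score.py --image_dir sampled_imgs --save_path result.csv
}

DRY_RUN=1 eval_mjhq config_path.py ckpt_path.pth
```

Dropping `DRY_RUN=1` executes the three steps in order, and `set -e` stops the chain if any step fails.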

#### GenEval

First, generate images from the GenEval prompts:

```bash
python scripts/inference_geneval.py \
--config config_path.py \
--load_base_shapes L28_width288.bsh \
--model_path ckpt_path.pth
```

Second, compute the scores:

```bash
python tools/evaluate_geneval.py \
--imagedir sampled_imgs \
--outfile sampled_imgs.jsonl \
--model-path output/pretrained_models/mask2former

python tools/summary_scores.py sampled_imgs.jsonl
```
